
import pandas as pd
import numpy as np
import jieba
import re
from sklearn.model_selection import train_test_split
# from gensim.models.word2vec import Word2Vec
from tensorflow import keras
from nlu_model.ptm.word2vector import Word2vector
from nlu_model.util.pkl_impl import save_pkl, load_pkl


def load_data(path):
    """Load tab-separated sentence-pair similarity data.

    Each line is expected to hold at least 4 tab-separated fields:
    ``<id>\\t<sentence_1>\\t<sentence_2>\\t<label>``.  Blank or malformed
    lines (fewer than 4 fields) are skipped instead of raising IndexError.

    Args:
        path: path to the TSV file.

    Returns:
        Tuple ``(sentence, label)`` where ``sentence`` is a list of
        ``[sentence_1, sentence_2]`` pairs and ``label`` a list of ints.
    """
    sentence = []
    label = []
    # Pin the encoding: the corpus is Chinese text and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 4:
                # Skip empty/truncated lines rather than crash mid-load.
                continue
            sentence.append([fields[1], fields[2]])
            label.append(int(fields[3]))
    return sentence, label

def single2double(sentence):
    """Split a list of [first, second] pairs into two parallel lists.

    Returns a tuple ``(firsts, seconds)``; an empty input yields two
    empty lists.
    """
    if not sentence:
        return [], []
    firsts, seconds = zip(*sentence)
    return list(firsts), list(seconds)

# Load the ATEC sentence-similarity corpus: each sample is a
# [sentence_1, sentence_2] pair with a binary label.
sentence, label = load_data("./data/sim/atec_sim/atec_nlp_sim_train_all.csv")

# 80/20 train/test split; fixed seed for reproducibility.
x_train, x_test, y_train, y_test = train_test_split(sentence, label, test_size=0.2,random_state=666)

# Split each pair into two parallel lists, one per model input branch.
x_train_1, x_train_2 = single2double(x_train)
x_test_1, x_test_2 = single2double(x_test)

word2vector = Word2vector()

# Load pre-trained word2vec vocabulary, embedding weights, and config.
# NOTE(review): these vectors were trained on shopping reviews while the
# task corpus is financial Q&A — confirm the domain mismatch is intended.
word2vector.load("./data/ptm/shopping_reviews/w2v_word2idx2020100601.pkl",
                 "./data/ptm/shopping_reviews/w2v_model_metric_2020100601.pkl", 
                 "./data/ptm/shopping_reviews/w2v_model_conf_2020100601.pkl")

# Tokenize each sentence with jieba and map tokens to vocabulary indices.
x_train_1 = word2vector.batch2idx([list(jieba.cut(i)) for i  in x_train_1])
x_train_2 = word2vector.batch2idx([list(jieba.cut(i)) for i  in x_train_2])
x_test_1 = word2vector.batch2idx([list(jieba.cut(i)) for i  in x_test_1])
x_test_2 = word2vector.batch2idx([list(jieba.cut(i)) for i  in x_test_2])

# data_preprocess
def _pad_to_maxlen(sequences):
    """Post-pad (or truncate) each id sequence with zeros to length 50."""
    return keras.preprocessing.sequence.pad_sequences(
        sequences, value=0, padding='post', maxlen=50)


x_train_1 = _pad_to_maxlen(x_train_1)
x_train_2 = _pad_to_maxlen(x_train_2)
x_test_1 = _pad_to_maxlen(x_test_1)
x_test_2 = _pad_to_maxlen(x_test_2)



# Two parallel inputs: one padded id sequence (length 50) per sentence.
inputs_1 = keras.layers.Input(shape=(50,))
inputs_2 = keras.layers.Input(shape=(50,))
# Single shared embedding layer (same weights for both branches),
# initialized from the pre-trained word2vec matrix.
embedding_layer = keras.layers.Embedding(output_dim = 300, # embedding vector length
                            input_dim = len(word2vector.embedding_weights), # vocabulary size
                            weights=[word2vector.get_np_weights()], # key point: pre-trained embedding weights
                            input_length=50, # max sentence length (must match padding above)
                            trainable=True # whether to update the embeddings during training
                            )
embedding_seq_1 = embedding_layer(inputs_1)
embedding_seq_2 = embedding_layer(inputs_2)
# Flatten each (50, 300) embedding sequence to a single vector, then
# concatenate the two branches for classification.
flatten_1 = keras.layers.Flatten()(embedding_seq_1)
flatten_2 = keras.layers.Flatten()(embedding_seq_2)
concat = keras.layers.concatenate([flatten_1, flatten_2])
output = keras.layers.Dense(32, activation='relu')(concat)
# Sigmoid output: probability the two sentences are semantically similar.
pred = keras.layers.Dense(units=1, activation='sigmoid')(output)

model = keras.models.Model(inputs=[inputs_1, inputs_2], outputs=pred)
model.summary()
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])

# Train for one epoch on the padded sequence pairs; the held-out split
# doubles as validation data here.
history = model.fit([x_train_1, x_train_2], y_train, batch_size=64,
                    epochs=1,
                    validation_data=([x_test_1, x_test_2], y_test),
                    verbose=1)
# Report [loss, accuracy] on the test split.
results = model.evaluate([x_test_1, x_test_2], y_test)
print(results)

# Sub-model mapping the second input ('input_2') to its flattened embedding.
# NOTE(review): 'flatten_1' is the Keras auto-generated name of the SECOND
# Flatten layer under TF2 naming ('flatten', 'flatten_1'), so this pairs
# input_2 with its own branch — confirm against model.summary() if the
# TF/Keras version names layers differently.
model_vec1 = keras.models.Model(inputs=model.get_layer('input_2').input,
                                outputs=model.get_layer('flatten_1').output)

sentence = "花呗怎么还"
# Tokenize and truncate to the model's maxlen up front.  The original code
# filled a fixed np.zeros(50) without truncating, so a sentence with more
# than 50 tokens raised IndexError, and its truncation/padding branches
# were dead code (len(sentence_id) was always exactly 50, and ndarrays
# have no .append anyway).
tokens = list(jieba.cut(sentence))[:50]
# Index reserved for out-of-vocabulary tokens (last row of the embedding
# matrix) — matches the original fallback of len(word2idx_dic) - 1.
unk_id = len(word2vector.word2idx_dic) - 1
# Zeros double as post-padding, consistent with
# pad_sequences(value=0, padding='post', maxlen=50) used for training data.
sentence_id = np.zeros(50)
for idx, token in enumerate(tokens):
    sentence_id[idx] = word2vector.word2idx_dic.get(token, unk_id)

print(len(sentence_id))
print(sentence_id)
# Flattened embedding of the probe sentence BEFORE the extra training epoch.
a = model_vec1.predict(np.array([sentence_id]))

# Train one more epoch, then re-extract the probe embedding to check
# whether the (trainable=True) embedding weights actually changed.
history = model.fit([x_train_1, x_train_2], y_train, batch_size=64,
                    epochs=1,
                    validation_data=([x_test_1, x_test_2], y_test),
                    verbose=1)
b = model_vec1.predict(np.array([sentence_id]))
# Elementwise comparison: mostly-False output means the embeddings moved
# during the extra epoch.
print(a==b)
